import pandas as pd
import numpy as np
import os
from scipy.integrate import cumulative_trapezoid
from scipy.interpolate import interp1d
from scipy.optimize import curve_fit

# --- CONFIGURATION ---
# Whitespace-delimited input table read by main(); lines starting with '#'
# are treated as comments. It must provide the columns 'zHD' and 'MU_SH0ES'.
INPUT_FILE = 'data/Pantheon+SH0ES.dat'
# Output paths normalized to 'produced' folder
# CSV with columns z_obs, mu_obs, dL_Mpc, r_opt_Mpc.
OUTPUT_FILE_PRED = 'produced/model_predictions_1701.csv'       # Apparent/Observed
# CSV with columns z_obs, mu_obs, z_clean, v0_derived_km_s, r_opt_Mpc, r_true_Mpc.
OUTPUT_FILE_REFINED = 'produced/refined_actual_data_1701.csv'  # Physical/derived
# Plain-text report of the fitted K / eta parameters.
PARAM_FILE = 'produced/appendix_B_fit_parameters.txt'

# Speed of light in km/s; scales every velocity computed in main().
C_LIGHT = 299792.458
# Constant redshift removed from each observed redshift via
# (1 + z_obs) / (1 + Z_STAR_MASS) - 1.
# NOTE(review): presumably a redshift attributed to the source star's mass -
# the origin of the value 0.0003 is not visible in this file; confirm.
Z_STAR_MASS = 0.0003

def _exponential_velocity(r, K):
    """Exponential velocity law ``v(r) = C_LIGHT * (1 - exp(-r / K))``.

    Args:
        r: Distance(s), scalar or NumPy array, in the same unit as ``K``.
        K: Scale radius of the law.

    Returns:
        Velocity in the unit of ``C_LIGHT``, same shape as ``r``.
    """
    return C_LIGHT * (1.0 - np.exp(-r / K))


def _sigmoidal_velocity(r, K, eta):
    """Sigmoidal velocity law ``v(r) = C_LIGHT * x / (1 + x)``, ``x = (r / K)**eta``.

    Args:
        r: Distance(s), scalar or NumPy array, in the same unit as ``K``.
        K: Scale radius of the law.
        eta: Exponent controlling the steepness of the transition.

    Returns:
        Velocity in the unit of ``C_LIGHT``, same shape as ``r``.
    """
    x = (r / K) ** eta
    return C_LIGHT * (x / (1.0 + x))


def _load_supernovae(path):
    """Read the catalogue at *path* and return the cleaned sample.

    The file is parsed as a whitespace-delimited table ('#' starts a comment).
    Columns 'zHD' and 'MU_SH0ES' are coerced to numbers, rows with
    non-numeric values are dropped and only rows with ``z_obs > 0.001`` are
    kept.

    Args:
        path: Location of the input table.

    Returns:
        A DataFrame with columns ``z_obs`` and ``mu_obs``, or ``None`` when the
        file cannot be read, a required column is missing, or no row survives
        the cleaning. A message is printed in each failure case.
    """
    try:
        df_raw = pd.read_csv(path, sep=r'\s+', comment='#')
    except Exception as err:  # reporting boundary: any read/parse failure ends the run
        print('Error reading file: ' + str(err))
        return None

    if ('zHD' not in df_raw.columns) or ('MU_SH0ES' not in df_raw.columns):
        print("CRITICAL ERROR: Columns 'zHD' or 'MU_SH0ES' not found.")
        return None

    df = pd.DataFrame()
    df['z_obs'] = pd.to_numeric(df_raw['zHD'], errors='coerce')
    df['mu_obs'] = pd.to_numeric(df_raw['MU_SH0ES'], errors='coerce')
    df = df.dropna()
    df = df[df['z_obs'] > 0.001].copy()
    print('Loaded ' + str(len(df)) + ' supernovae.')

    # Without this guard an empty sample only fails later, inside curve_fit,
    # with an error that does not point at the data.
    if df.empty:
        print('CRITICAL ERROR: No usable supernovae after cleaning.')
        return None
    return df


def _rectify_distances(r_opt, v_obs, K_start=4400.0, tolerance=0.5, max_iter=10):
    """Iteratively convert optical distances to rectified ones and fit ``K``.

    Each pass builds the index profile ``n(r) = 1 + (v(r) / C_LIGHT)**2`` from
    the current exponential law, integrates it into an optical-distance grid,
    inverts that grid by linear interpolation to map ``r_opt`` back to a
    rectified distance, and refits ``K`` against ``v_obs``. Iteration stops
    once ``K`` changes by less than ``tolerance`` or after ``max_iter`` passes.

    Args:
        r_opt: NumPy array of optical distances.
        v_obs: NumPy array of velocities, same length as ``r_opt``.
        K_start: Initial guess for the scale radius.
        tolerance: Convergence threshold on the change of ``K`` per pass.
        max_iter: Maximum number of passes; must be at least 1.

    Returns:
        Tuple ``(K, r_true)``: the last fitted scale radius and the rectified
        distances computed during the last pass.

    Raises:
        ValueError: If ``max_iter`` is smaller than 1.
        RuntimeError: Propagated from ``curve_fit`` when a fit does not converge.
    """
    if max_iter < 1:
        raise ValueError('max_iter must be at least 1')

    r_grid_max = 25000
    r_true_grid = np.linspace(0, r_grid_max, 5000)
    K_current = K_start

    for i in range(max_iter):
        v_model = _exponential_velocity(r_true_grid, K_current)
        n_grid = 1.0 + (v_model / C_LIGHT) ** 2
        d_opt_grid = cumulative_trapezoid(n_grid, r_true_grid, initial=0.0)

        # n_grid >= 1, so d_opt_grid is strictly increasing and can be inverted.
        optical_to_true = interp1d(
            d_opt_grid, r_true_grid, kind='linear', fill_value='extrapolate'
        )
        r_true = optical_to_true(r_opt)

        popt, _ = curve_fit(_exponential_velocity, r_true, v_obs, p0=[K_current])
        K_new = popt[0]

        diff = abs(K_new - K_current)
        print(f"Iteration {i+1}: K_in={K_current:.1f} -> K_out={K_new:.1f}")

        K_current = K_new
        if diff < tolerance:
            print(f"Converged! Physical Scale Radius K = {K_current:.2f} Mpc")
            break
    else:
        print("Warning: Max iterations reached.")

    # Points beyond the grid are linearly extrapolated by interp1d; make that
    # visible instead of letting it pass silently.
    n_beyond = int(np.count_nonzero(r_opt > d_opt_grid[-1]))
    if n_beyond:
        print(f"Warning: {n_beyond} distances exceed the rectification grid "
              "and were linearly extrapolated.")

    return K_current, r_true


def main():
    """Run the full pipeline: load, derive, rectify, export and report.

    Steps:
        1. Read ``INPUT_FILE`` and keep numeric rows with ``z_obs > 0.001``.
        2. Remove the constant ``Z_STAR_MASS`` redshift and derive a velocity
           and an optical distance for every row.
        3. Iteratively rectify the distances while fitting the exponential
           scale radius ``K``.
        4. Write ``OUTPUT_FILE_PRED`` and ``OUTPUT_FILE_REFINED`` as CSV.
        5. Fit the sigmoidal law and write the report to ``PARAM_FILE``.

    Returns:
        None. On unreadable input, missing columns or an empty sample a message
        is printed and the function returns without writing any output file.
    """
    print('Reading raw data from ' + str(INPUT_FILE) + '...')

    # Ensure output directories exist. They are derived from the configured
    # paths so that editing a path constant cannot break the writes below.
    for out_path in (OUTPUT_FILE_PRED, OUTPUT_FILE_REFINED, PARAM_FILE):
        out_dir = os.path.dirname(out_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
    # Nothing here writes to 'plots'; it is still created, as before.
    # NOTE(review): presumably for companion plotting scripts - confirm.
    os.makedirs('plots', exist_ok=True)

    # 1. READ RAW DATA
    df = _load_supernovae(INPUT_FILE)
    if df is None:
        return

    # 2. PHYSICS STEP 1: STAR MASS CORRECTION
    df['z_clean'] = (1.0 + df['z_obs']) / (1.0 + Z_STAR_MASS) - 1.0

    # 3. PHYSICS STEP 2: ENERGY STATE INVERSION
    df['v0_derived_km_s'] = C_LIGHT * (df['z_clean'] / (1.0 + df['z_clean']))

    # 4. PHYSICS STEP 3: APPARENT DISTANCE (Geometric Optics)
    df['dL_Mpc'] = 10.0 ** ((df['mu_obs'] - 25.0) / 5.0)
    df['r_opt_Mpc'] = df['dL_Mpc'] / (1.0 + df['z_clean'])

    # 5. PHYSICS STEP 4: ITERATIVE RECTIFICATION
    print('\n--- ITERATIVE RECTIFICATION ---')
    K_current, r_true_rectified = _rectify_distances(
        df['r_opt_Mpc'].values, df['v0_derived_km_s'].values
    )
    df['r_true_Mpc'] = r_true_rectified

    # 6. EXPORT DISTINCT FILES

    # File 1: Apparent / Observational Data (No Refractive Physics)
    cols_pred = ['z_obs', 'mu_obs', 'dL_Mpc', 'r_opt_Mpc']
    df[cols_pred].to_csv(OUTPUT_FILE_PRED, index=False)
    print(f'Saved Apparent Data (Observed) to {OUTPUT_FILE_PRED}')

    # File 2: Refined / Physical Data (With Refractive Physics)
    # mu_obs is included so downstream consumers of this file can use it.
    cols_refined = ['z_obs', 'mu_obs', 'z_clean', 'v0_derived_km_s', 'r_opt_Mpc', 'r_true_Mpc']
    df[cols_refined].to_csv(OUTPUT_FILE_REFINED, index=False)
    print(f'Saved Refined Data (Physical) to {OUTPUT_FILE_REFINED}')

    # 7. GENERATE PARAMETER FILE
    sig_note = ''
    try:
        popt_sig, _ = curve_fit(
            _sigmoidal_velocity,
            df['r_true_Mpc'].values,
            df['v0_derived_km_s'].values,
            p0=[3000, 1.0],
        )
        K_sig, eta_sig = popt_sig
    except (RuntimeError, ValueError, TypeError) as err:
        # Best effort: keep producing a report, but say that the numbers are
        # not fitted. RuntimeError = no convergence, ValueError = non-finite
        # data, TypeError = fewer data points than parameters.
        print('Warning: sigmoidal fit failed (' + str(err) + '); '
              'using fallback K=3232.0, eta=1.12.')
        K_sig, eta_sig = 3232.0, 1.12
        sig_note = '\n   NOTE: fit failed; K and eta above are hard-coded fallbacks.'

    # sig_note is empty on success, so the report text is unchanged in that case.
    report = f"""
Fit Parameters Derived from Delensed Data:

1. Exponential Fit (v = c * (1 - exp(-r/K)))
   K = {K_current:.4f} Mpc

2. Unconstrained Sigmoidal (v = c * (r/K)^eta / (1 + (r/K)^eta))
   K = {K_sig:.4f} Mpc
   eta = {eta_sig:.4f}{sig_note}

3. Refractive Root (Placeholder)
   K = {K_current:.4f} Mpc
"""
    with open(PARAM_FILE, 'w', encoding='utf-8') as f:
        f.write(report)
    print(f"Parameters saved to {PARAM_FILE}")

# Run the pipeline only when this file is executed as a script, so importing
# the module has no side effects beyond defining names.
if __name__ == '__main__':
    main()